#' Plot multiple graphs at the same time
#'
#' Plot multiple graphs on a single page, aligned on a grid of rows and
#' columns. The grid is filled column by column (the default fill order of
#' \code{matrix}).
#'
#' @importFrom data.table data.table
#' @param ... plot objects (e.g. \code{ggplot} objects) to draw, in order.
#' @param cols number of columns of the layout grid.
#' @return \code{NULL}, invisibly. Called for its side effect of drawing.
multiplot <- function(..., cols = 1) {
  plots <- list(...)
  num.plots <- length(plots)

  # Nothing to draw: bail out before building a degenerate layout.
  if (num.plots == 0) {
    return(invisible(NULL))
  }

  # A single plot needs no grid layout at all.
  if (num.plots == 1) {
    print(plots[[1]])
    return(invisible(NULL))
  }

  rows <- ceiling(num.plots / cols)
  # Cell (r, c) holds the index of the plot drawn there; matrix() fills
  # column-wise, so plots are laid out top-to-bottom, then left-to-right.
  layout <- matrix(seq_len(cols * rows), ncol = cols, nrow = rows)

  grid::grid.newpage()
  grid::pushViewport(grid::viewport(layout = grid::grid.layout(rows, cols)))
  for (i in seq_len(num.plots)) {
    # which(arr.ind = TRUE) returns a matrix with "row" and "col" columns
    # giving the layout cell that contains plot i.
    match.idx <- which(layout == i, arr.ind = TRUE)
    print(
      plots[[i]],
      vp = grid::viewport(
        layout.pos.row = match.idx[, "row"],
        layout.pos.col = match.idx[, "col"]
      )
    )
  }
  invisible(NULL)
}

#' Summarise a root-to-leaf path
#'
#' Reduce a path to its final vertex and its length.
#'
#' @param element igraph vertex sequence describing the path from the root to
#'   a leaf (in this file, one element of the \code{vpath} list returned by
#'   \code{igraph::shortest_paths}).
#' @return An unnamed list of two elements: the ID of the last vertex of the
#'   path and the number of vertices on the path.
edge.parser <- function(element) {
  vertex.ids <- igraph::as_ids(element)
  n.vertices <- length(vertex.ids)
  # Indexing with the length picks the final ID (or nothing for an empty path).
  last.vertex <- vertex.ids[n.vertices]
  list(last.vertex, n.vertices)
}

#' Extract the paths from the root to every leaf
#'
#' Build one directed graph per tree from the split nodes' \code{Yes} / \code{No}
#' children and compute the shortest path from the tree root to each leaf.
#'
#' @param dt.tree data.table containing the nodes and edges of the trees
#'   (as produced by \code{xgb.model.dt.tree}); the columns \code{ID},
#'   \code{Feature}, \code{Yes}, \code{No} and \code{Tree} are used.
#' @return A flat list of igraph vertex sequences, one per leaf, each starting
#'   at the root of its tree (node ID \code{"<tree>-0"}).
get.paths.to.leaf <- function(dt.tree) {
  # Every split node contributes two edges: to its "Yes" child and to its
  # "No" child. Column names are taken from the first table.
  dt.not.leaf.edges <- rbindlist(
    list(
      dt.tree[Feature != "Leaf", .(ID, Yes, Tree)],
      dt.tree[Feature != "Leaf", .(ID, No, Tree)]
    ),
    use.names = FALSE
  )

  trees <- dt.tree[, unique(Tree)]

  paths.per.tree <- lapply(trees, function(tree) {
    graph <- igraph::graph_from_data_frame(dt.not.leaf.edges[Tree == tree])
    leaves <- dt.tree[Tree == tree & Feature == "Leaf", c(ID)]
    igraph::shortest_paths(graph, from = paste0(tree, "-0"), to = leaves)$vpath
  })

  # Concatenate the per-tree lists once; the leading list() keeps the result
  # a (possibly empty) list even when there are no trees.
  do.call(c, c(list(list()), paths.per.tree))
}

#' Plot model trees deepness
#'
#' Generate a graph to plot the distribution of deepness among trees.
#'
#' @importFrom data.table data.table
#' @importFrom data.table rbindlist
#' @importFrom data.table setnames
#' @importFrom data.table :=
#' @importFrom magrittr %>%
#' @param model an \code{xgb.Booster} model generated by the \code{xgb.train} function.
#'
#' @return \code{NULL}; called for its side effect of drawing two graphs showing
#' the distribution of the model deepness.
#'
#' @details
#' Display both the number of leaves and the distribution of weighted observations
#' (\code{Cover}) by tree deepness level.
#'
#' The purpose of this function is to help the user find the best trade-off when setting
#' the \code{max.depth} and \code{min_child_weight} parameters according to the bias / variance trade-off.
#'
#' See \link{xgb.train} for more information about these parameters.
#'
#' The graph is made of two parts:
#'
#' \itemize{
#'  \item Count: number of leaves per root-to-leaf path length;
#'  \item Weighted cover: fraction of the total cover (weighted number of instances)
#'  held by the leaves at each path length.
#' }
#'
#' This function is inspired by the blog post \url{http://aysent.github.io/2015/11/08/random-forest-leaf-visualization.html}
#'
#' @examples
#' data(agaricus.train, package='xgboost')
#'
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max.depth = 15,
#'                  eta = 1, nthread = 2, nround = 30, objective = "binary:logistic",
#'                  min_child_weight = 50)
#'
#' xgb.plot.deepness(model = bst)
#'
#' @export
xgb.plot.deepness <- function(model = NULL) {
  # All three packages are needed; fail early with an explicit message.
  for (pkg in c("ggplot2", "igraph", "grid")) {
    if (!requireNamespace(pkg, quietly = TRUE)) {
      stop(pkg, " package is required for plotting the graph deepness.",
           call. = FALSE)
    }
  }

  # inherits() is safe for objects carrying several classes, where
  # class(x) != "..." would give a length > 1 condition.
  if (!inherits(model, "xgb.Booster")) {
    stop("model: Has to be an object of class xgb.Booster model generated by the xgb.train function.")
  }

  dt.tree <- xgb.model.dt.tree(model = model)
  paths <- get.paths.to.leaf(dt.tree)

  # One row per leaf: its ID, the number of vertices on its root-to-leaf path,
  # and the node attributes (notably Cover) joined from dt.tree.
  dt.edge.elements <- lapply(paths, edge.parser) %>%
    rbindlist() %>%
    setnames(c("last.edge", "size")) %>%
    merge(dt.tree, by.x = "last.edge", by.y = "ID")

  # Leaf count and total cover per path length, with cover normalised to sum to 1.
  dt.edge.summary <-
    dt.edge.elements[, .(.N, Cover = sum(Cover)), by = size][, Cover := Cover / sum(Cover)]

  p1 <- ggplot2::ggplot(dt.edge.summary) +
    ggplot2::geom_line(ggplot2::aes(x = size, y = N, group = 1)) +
    ggplot2::xlab("") +
    ggplot2::ylab("Count") +
    ggplot2::ggtitle("Model complexity") +
    ggplot2::theme(
      plot.title = ggplot2::element_text(lineheight = 0.9, face = "bold"),
      panel.grid.major.y = ggplot2::element_blank(),
      axis.ticks = ggplot2::element_blank(),
      axis.text.x = ggplot2::element_blank()
    )

  p2 <- ggplot2::ggplot(dt.edge.summary) +
    ggplot2::geom_line(ggplot2::aes(x = size, y = Cover, group = 1)) +
    ggplot2::xlab("From root to leaf path length") +
    ggplot2::ylab("Weighted cover")

  multiplot(p1, p2, cols = 1)
}

# Avoid "no visible binding for global variable" notes during CRAN check.
# These names are never declared as variables: they are mainly data.table
# column names (and data.table helpers such as `.` and `.N`) referenced
# through non-standard evaluation.
globalVariables(
  c(
    "Feature", "Count", "ggplot", "aes", "geom_bar", "xlab", "ylab", "ggtitle",
    "theme", "element_blank", "element_text", "ID", "Yes", "No", "Tree",
    "Cover", "size", "N", ".", ".N"
  )
)
